/*
 * Copyright 2018- The Pixie Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

#include <gflags/gflags.h>
#include <gmock/gmock.h>
#include <memory>
#include <numeric>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "src/carnot/funcs/builtins/ml_ops.h"
#include "src/carnot/udf/test_utils.h"
#include "src/common/base/base.h"

#include "src/carnot/exec/ml/eigen_test_utils.h"

// Locations of the ML model assets exercised by the tests below. Both default
// to the empty string, so they are expected to be supplied on the command line
// when the test binary runs (presumably by the build/test rule — TODO confirm).
//   --sentencepiece_dir is passed to SentencePieceUDF (TEST(SentencePiece, basic)).
//   --embedding_dir is passed to TransformerUDF (TEST(Transformer, basic)).
// NOTE(review): the flag names end in "_dir" but the help strings describe a
// path to a .proto file; confirm which one the UDFs actually expect.
DEFINE_string(sentencepiece_dir, "", "Path to sentencepiece.proto");
DEFINE_string(embedding_dir, "", "Path to embedding.proto");

namespace px {
namespace carnot {
namespace builtins {

using ::px::carnot::udf::FunctionContext;

// Serializes an Eigen float vector as a compact JSON array of numbers
// (e.g. "[1.0,2.5]"), in element order, using rapidjson's Writer. Each float
// element is widened to double before being written. Used below to turn a
// single data point into the JSON input format consumed by KMeansUDA.
std::string write_vector_to_json(const Eigen::VectorXf& vector) {
  rapidjson::StringBuffer buffer;
  rapidjson::Writer<rapidjson::StringBuffer> json_writer(buffer);

  json_writer.StartArray();
  // For a column vector, size() equals rows().
  const Eigen::Index num_elements = vector.size();
  for (Eigen::Index idx = 0; idx < num_elements; ++idx) {
    const double element = vector(idx);
    json_writer.Double(element);
  }
  json_writer.EndArray();

  return std::string(buffer.GetString());
}

// Streams every row of the KMeans fixture data through KMeansUDA as a JSON
// point, decodes the UDA's JSON result into an exec::ml::KMeans model, and
// checks that its centroids match the expected ones up to row order, with a
// tolerance of 0.1.
TEST(KMeans, basic) {
  int num_clusters = 3;
  int point_dim = 2;

  auto tester = udf::UDATester<KMeansUDA>(point_dim);

  Eigen::MatrixXf expected_centroids = kmeans_expected_centroids();
  Eigen::MatrixXf data = kmeans_test_data();

  for (int row = 0; row < data.rows(); ++row) {
    // A matrix row is a row vector; transpose it into the column vector that
    // write_vector_to_json expects.
    auto point_json = write_vector_to_json(data(row, Eigen::all).transpose());
    tester.ForInput(point_json, num_clusters);
  }

  auto result_json = tester.Result();
  px::carnot::exec::ml::KMeans decoded(num_clusters);
  decoded.FromJSON(result_json);
  EXPECT_THAT(decoded.centroids(), UnorderedRowsAre(expected_centroids, 0.1));
}

// Sanity check that SentencePieceUDF, constructed with --sentencepiece_dir,
// runs and tokenizes a short string into a JSON array of token ids.
// The expected ids depend on the model behind that flag: if the model
// changes, this test will almost certainly fail.
TEST(SentencePiece, basic) {
  auto tester = udf::UDFTester<SentencePieceUDF>(FLAGS_sentencepiece_dir);

  tester.ForInput("Test 123!");
  tester.Expect("[4,197,803,195,16,5001]");
}

// Sanity check that TransformerUDF, constructed with --embedding_dir and a
// FunctionContext backed by an exec::ml::ModelPool, runs on a fixed token-id
// input. The input token ids match the expected output of the SentencePiece
// test above.
TEST(Transformer, basic) {
  // `pool` is declared first so it outlives `ctx` and `udf_tester`, which only
  // receive a raw pointer to it via pool.get().
  auto pool = exec::ml::ModelPool::Create();
  // NOTE(review): the first FunctionContext argument is passed as nullptr;
  // presumably TransformerUDF does not use it — confirm.
  auto ctx = std::make_unique<FunctionContext>(nullptr, pool.get());
  auto udf_tester = udf::UDFTester<TransformerUDF>(std::move(ctx), FLAGS_embedding_dir);
  udf_tester.ForInput("[4,197,803,195,16,5001]");
  // This test is just a sanity check to see that the transformer UDF runs.
  // If the model changes this test will fail.
  // NOTE(review): this is an exact string match on serialized floats, so it may
  // also be sensitive to float rounding differences across platforms/builds;
  // consider a tolerance-based comparison if it proves flaky.
  udf_tester.Expect(
      "[8.423064231872559,1.762765645980835,17.635025024414064,15.878694534301758,-1."
      "3718032836914063,8.416397094726563,3.800554037094116,14.292364120483399,8.203149795532227,"
      "15.849493026733399,13.610809326171875,1.56343674659729,-0.3372507393360138,3."
      "7281088829040529,-0.05103045701980591,-18.241056442260743,-0.6284072399139404,7."
      "207020282745361,1.1521596908569337,1.0471590757369996,-0.24791300296783448,1."
      "916718602180481,1.257470965385437,1.1642893552780152,3.5091605186462404,2.920778751373291,-"
      "0.2878648042678833,14.827808380126954,-18.584875106811525,5.088066101074219,3."
      "0825138092041017,1.0002727508544922,2.281430244445801,13.94554328918457,2.6003968715667726,"
      "5.258618354797363,1.8324857950210572,-0.05988609790802002,0.14812558889389039,-0."
      "7637635469436646,0.24192380905151368,-0.09021991491317749,-0.07128120213747025,-4."
      "9098358154296879,2.119394540786743,13.733222007751465,10.318382263183594,-16.00957489013672,"
      "1.1407907009124756,-4.009923934936523,1.297436237335205,0.9833167791366577,-0."
      "029665619134902955,1.14364492893219,19.570220947265626,-6.751187324523926,4.048876762390137,"
      "11.755240440368653,-0.18544794619083405,-1.025405764579773,-0.3492504358291626,23."
      "207979202270509,-0.143155038356781,5.322646141052246,0.0,0.0,0.0,1.1513252258300782,0."
      "358573853969574,0.10737401247024536,0.4662875831127167,3.5053648948669435,1."
      "1701009273529053,1.101596713066101,-0.008506953716278077,-0.040919721126556399,0."
      "22259575128555299,-0.08148394525051117,0.5089952945709229,1.4176372289657593,1."
      "401298464324817e-45,0.15878546237945558,1.1605896949768067,0.4044678211212158,0."
      "9227162003517151,1.0621774196624756,0.4838687777519226,1.5155175924301148,0."
      "4667109549045563,1.2825508117675782,1.0523432493209839,1.8215129375457764,2.253741979598999,"
      "0.480151891708374,0.5978711843490601,0.664787232875824,0.36510878801345827,1."
      "373420000076294,1.1708780527114869,0.8425347805023193,0.34504178166389468,1."
      "1472735404968262,0.4935304522514343,-0.836276113986969,-1.4642415046691895,1."
      "1901633739471436,0.15823155641555787,1.858086109161377,1.2436692714691163,0."
      "3054755628108978,0.3168545365333557,0.8003208041191101,1.1527131795883179,0."
      "8839558362960815,0.38009387254714968,0.3699653148651123,0.06293665617704392,1."
      "1082345247268677,1.426318645477295,1.9110982418060303,0.8831897377967835,2.389556646347046,-"
      "0.12238005548715592,0.9231411218643189,0.7139260172843933,0.07743673026561737,-0."
      "09491483867168427,1.3351469039916993,1.4627797603607178,-1.4475688934326172,0."
      "31976258754730227,0.7541965842247009,0.11952435970306397,-0.5786635875701904,-0."
      "03774666786193848,3.167584180831909,0.8516446352005005,0.4984332323074341,-0."
      "6286579370498657,-0.650283694267273,0.06163284182548523,-0.4208141565322876,0."
      "23980867862701417,0.6349300146102905,1.0345019102096558,-0.6285422444343567,0."
      "7422002553939819,0.1268576681613922,0.6671826839447022,0.716931939125061,0.1698441505432129,"
      "1.1895318031311036,0.13287216424942017,0.6991399526596069,0.7395703792572022,1."
      "1269493103027344,1.5145041942596436,0.25194573402404787,0.2485913634300232,0."
      "35495102405548098,0.1102229654788971,0.9173923134803772,0.7953489422798157,0."
      "45405471324920657,0.14721998572349549,0.6640918850898743,0.34429192543029787,-1."
      "4020744562149048,-1.473140001296997,0.9868255853652954,0.11252009123563767,1."
      "3520417213439942,0.6468302011489868,-0.06478893756866455,-0.13571488857269288,-0."
      "03130674362182617,0.8152022957801819,0.5368598699569702,0.16718798875808717,0."
      "07826703786849976,0.0529685914516449,0.8314865827560425,0.7944748401641846,1."
      "3612890243530274,0.5407318472862244,1.8429951667785645,-0.12639638781547547,0."
      "6101526618003845,0.5273962020874023,-0.18153735995292664,-0.10076280683279038,1."
      "170333743095398,1.5170931816101075,-1.4206146001815797,0.39199161529541018,0."
      "8010396957397461,0.1477212905883789,-0.4977424144744873,0.02170652151107788,3."
      "2074270248413088,0.8892079591751099,0.569579005241394,-0.5555083751678467,-0."
      "5784064531326294,0.08061912655830383,-0.3807886242866516,0.271560400724411,0."
      "727253794670105,1.0666420459747315,-0.5356735587120056,0.7915511727333069,0."
      "1596030294895172,0.6973240375518799,0.757655143737793,0.2068847417831421,1.2279833555221558,"
      "0.1722499132156372,0.7679558992385864,0.7764633893966675,1.2088762521743775,1."
      "6017005443572999,0.27886366844177248,0.2897905111312866,0.3914976119995117,0."
      "1402878761291504,0.9711828231811523,0.8396442532539368,0.4998776912689209,0."
      "17055395245552064,0.7210853695869446,0.36189526319503786,-1.3353359699249268,-1."
      "4720903635025025,1.010810136795044,0.11791196465492249,1.4117319583892823,0."
      "7172300219535828,-0.021114587783813478,-0.0823323130607605,0.06678760051727295,0."
      "8550132513046265,0.5778014659881592,0.1923011839389801,0.11267414689064026,0."
      "05414436757564545,0.8641302585601807,0.8690036535263062,1.4261415004730225,0.58112633228302,"
      "1.9074645042419434,-0.12592265009880067,0.6470710635185242,0.5493981838226318,-0."
      "15099024772644044,-0.10007300972938538,1.1897741556167603]");
}

}  // namespace builtins
}  // namespace carnot
}  // namespace px
